# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from functools import partial
from typing import Union, Sequence, Any, Optional, Callable

import jax
import jax.numpy as jnp

from brainpy.mixin import ReturnInfo
from brainpy import math as bm
from brainpy.context import share
from brainpy.dyn.base import NeuDyn
from brainpy.dyn.utils import get_spk_type
from brainpy.initialize import parameter, variable_
from brainpy.types import Shape, ArrayType

# Public API of this module: placeholder groups (Input/Output) and
# stimulus-generating groups (SpikeTimeGroup, PoissonGroup).
__all__ = [
    'InputGroup',
    'OutputGroup',
    'SpikeTimeGroup',
    'PoissonGroup',
]


class InputGroup(NeuDyn):
    """Placeholder neuron group that forwards external input unchanged.

    Args:
      size: int, tuple of int. The geometry of the neuron group.
      sharding: The axis names for parallelization (passed to the base class).
      keep_size: bool. Whether to keep the multi-dimensional geometry.
      mode: Mode. The computing mode.
      name: str. The name of the group.
    """

    def __init__(
        self,
        size: Union[int, Sequence[int]],
        sharding: Any = None,
        keep_size: bool = False,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
    ):
        super().__init__(
            size=size,
            keep_size=keep_size,
            sharding=sharding,
            mode=mode,
            name=name,
        )

    def update(self, x):
        # Identity: the group simply passes its input through.
        return x

    def return_info(self):
        # The group's output matches its shape/sharding/mode, zero-initialized.
        return ReturnInfo(self.varshape, self.sharding, self.mode, bm.zeros)

    def reset_state(self, batch_or_mode=None, **kwargs):
        # Stateless group: nothing to reset.
        pass


class OutputGroup(NeuDyn):
    """Placeholder neuron group that relays its input as the network output.

    Args:
      size: int, tuple of int. The geometry of the neuron group.
      sharding: The axis names for parallelization (passed to the base class).
      keep_size: bool. Whether to keep the multi-dimensional geometry.
      mode: Mode. The computing mode.
      name: str. The name of the group.
    """

    def __init__(
        self,
        size: Union[int, Sequence[int]],
        sharding: Any = None,
        keep_size: bool = False,
        mode: Optional[bm.Mode] = None,
        name: Optional[str] = None,
    ):
        super().__init__(
            size=size,
            keep_size=keep_size,
            sharding=sharding,
            mode=mode,
            name=name,
        )

    def update(self, x):
        # Identity: the group simply passes its input through.
        return x

    def return_info(self):
        # The group's output matches its shape/sharding/mode, zero-initialized.
        return ReturnInfo(self.varshape, self.sharding, self.mode, bm.zeros)

    def reset_state(self, batch_size=None, **kwargs):
        # Stateless group: nothing to reset.
        pass


class SpikeTimeGroup(NeuDyn):
    """The input neuron group characterized by spikes emitting at given times.

    >>> # Get 2 neurons, firing spikes at 10 ms and 20 ms.
    >>> SpikeTimeGroup(2, times=[10, 20])
    >>> # or
    >>> # Get 2 neurons, the neuron 0 fires spikes at 10 ms and 20 ms.
    >>> SpikeTimeGroup(2, times=[10, 20], indices=[0, 0])
    >>> # or
    >>> # Get 2 neurons, neuron 0 fires at 10 ms and 30 ms, neuron 1 fires at 20 ms.
    >>> SpikeTimeGroup(2, times=[10, 20, 30], indices=[0, 1, 0])
    >>> # or
    >>> # Get 2 neurons; at 10 ms, neuron 0 fires; at 20 ms, neuron 0 and 1 fire;
    >>> # at 30 ms, neuron 1 fires.
    >>> SpikeTimeGroup(2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1])

    Parameters::

    size : int, tuple, list
        The neuron group geometry.
    indices : list, tuple, ArrayType
        The neuron indices at each time point to emit spikes.
    times : list, tuple, ArrayType
        The time points which generate the spikes.
    name : str, optional
        The name of the dynamic system.
    """

    def __init__(
        self,
        size: Union[int, Sequence[int]],
        indices: Union[Sequence, ArrayType],
        times: Union[Sequence, ArrayType],
        spk_type: Optional[type] = None,
        name: Optional[str] = None,
        sharding: Optional[Sequence[str]] = None,
        keep_size: bool = False,
        mode: Optional[bm.Mode] = None,
        need_sort: bool = True,
    ):
        super().__init__(size=size,
                         sharding=sharding,
                         name=name,
                         keep_size=keep_size,
                         mode=mode)

        # parameters
        if keep_size:
            raise NotImplementedError(f'Do not support keep_size=True in {self.__class__.__name__}')
        # `indices` and `times` are parallel arrays: indices[k] fires at times[k].
        if len(indices) != len(times):
            raise ValueError(f'The length of "indices" and "times" must be the same. '
                             f'However, we got {len(indices)} != {len(times)}.')
        self.num_times = len(times)
        self.spk_type = get_spk_type(spk_type, self.mode)

        # data about times and indices
        self.times = bm.asarray(times)
        self.indices = bm.asarray(indices, dtype=bm.int_)
        if need_sort:
            # Sort both arrays by time so `update` can consume them with a
            # single monotonically advancing cursor (self.i).
            sort_idx = jnp.argsort(self.times.value)
            self.indices.value = self.indices[sort_idx]
            self.times.value = self.times[sort_idx]

        # variables
        self.reset_state(self.mode)

    def reset_state(self, batch_size=None, **kwargs):
        # Cursor into the sorted (times, indices) arrays: position of the next
        # spike event that has not yet been emitted.
        self.i = bm.Variable(bm.asarray(0))
        # Spike output variable, zero-initialized with the group's spike dtype.
        self.spike = variable_(partial(jnp.zeros, dtype=self.spk_type),
                               self.varshape,
                               batch_size,
                               axis_names=self.sharding,
                               batch_axis_name=bm.sharding.BATCH_AXIS)

    def update(self):
        """Emit all spikes scheduled at or before the current time ``share['t']``."""
        # self.spike.value = bm.sharding.partition(bm.zeros_like(self.spike), self.spike.sharding)
        # Clear last step's spikes, then replay every event whose time has arrived.
        self.spike.value = bm.zeros_like(self.spike)
        bm.while_loop(self._body_fun, self._cond_fun, ())
        return self.spike.value

    def return_info(self):
        # The observable output of this group is its spike variable.
        return self.spike

    # functions
    def _cond_fun(self):
        # Keep looping while events remain AND the next event's time has passed.
        # NOTE(review): when i == num_times, self.times[i] is out of range;
        # this presumably relies on JAX clamping out-of-bounds indices (and on
        # the short-circuiting effect of the left operand being False) — confirm.
        i = self.i.value
        return bm.logical_and(i < self.num_times, share['t'] >= self.times[i])

    def _body_fun(self):
        # Mark the neuron of the current event as spiking and advance the cursor.
        i = self.i.value
        if isinstance(self.mode, bm.BatchingMode):
            # Batched spike variable: axis 0 is the batch axis.
            self.spike[:, self.indices[i]] = True
        else:
            self.spike[self.indices[i]] = True
        self.i += 1


class PoissonGroup(NeuDyn):
    """Poisson neuron group.

    At every simulation step, each neuron fires independently with
    probability ``freqs * dt / 1000``, i.e. ``freqs`` is interpreted as a
    firing rate in Hertz while the step size ``dt`` is in milliseconds.

    Args:
      size: The geometry of the neuron group.
      freqs: The firing rate(s) in Hz — a scalar, an array broadcastable to
        the number of neurons, or a callable used for per-neuron initialization.
      keep_size: Whether to keep the multi-dimensional geometry.
      sharding: The axis names for parallelization.
      spk_type: The data type of the emitted spikes.
      name: The name of the group.
      mode: The computing mode.
    """

    def __init__(
        self,
        size: Shape,
        freqs: Union[int, float, jax.Array, bm.Array, Callable],
        keep_size: bool = False,
        sharding: Optional[Sequence[str]] = None,
        spk_type: Optional[type] = None,
        name: Optional[str] = None,
        mode: Optional[bm.Mode] = None,
    ):
        super().__init__(size=size,
                         sharding=sharding,
                         name=name,
                         keep_size=keep_size,
                         mode=mode)

        # parameters: expand `freqs` to one rate value per neuron.
        self.freqs = parameter(freqs, self.num, allow_none=False)
        self.spk_type = get_spk_type(spk_type, self.mode)

        # variables
        self.reset_state(self.mode)

    def update(self):
        """Draw one step of Bernoulli spikes at per-neuron rate ``freqs`` Hz."""
        # freqs [Hz] * dt [ms] / 1000 gives the per-step spike probability.
        spikes = bm.random.rand_like(self.spike) <= (self.freqs * share['dt'] / 1000.)
        spikes = bm.asarray(spikes, dtype=self.spk_type)
        self.spike.value = spikes
        return spikes

    def return_info(self):
        # The observable output of this group is its spike variable.
        return self.spike

    def reset_state(self, batch_or_mode=None, **kwargs):
        # (Re-)allocate the spike variable, zero-initialized with `spk_type`.
        self.spike = self.init_variable(partial(jnp.zeros, dtype=self.spk_type), batch_or_mode)
